{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.sys.path.append(os.path.dirname(os.path.abspath('.')))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from datasets.dataset import load_boston\n",
    "data=load_boston()\n",
    "X,Y=data.data,data.target\n",
    "del data\n",
    "\n",
    "from model_selection.train_test_split import train_test_split\n",
    "X_train,X_test,Y_train,Y_test=train_test_split(X,Y,test_size=0.2)\n",
    "\n",
    "# print(X_train.shape,X_test.shape,Y_train.shape,Y_test.shape)\n",
    "\n",
    "# 把X，Y拼起来便于操作\n",
    "training_data=np.c_[X_train,Y_train]\n",
    "testing_data=np.c_[X_test,Y_test]\n",
    "\n",
    "# print(training_data.shape,testing_data.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 模型基础\n",
    "对于CART用做回归时，分裂的依据有很多的可选方法，这里使用MSE作为分裂依据。注意这里的MSE跟模型评估时的MSE不同，这里的MSE计算方式为：\n",
    "$$\n",
    "H(X_{m})=\\frac{1}{N_{m}}\\sum\\limits_{i{\\in}N_{m}}(y_{i}-\\bar{y}_{m})^{2}\n",
    "$$\n",
    "其中$\\bar{y}_{m}$为第$m$个叶节点中所有训练样本的目标值均值，详见[这里](https://scikit-learn.org/stable/modules/tree.html#regression-criteria)。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MSE(data,y_idx=-1):\n",
    "    '''\n",
    "    返回数据集目标值的MSE\n",
    "    '''\n",
    "    N=len(data)\n",
    "    mean=np.mean(data[:,y_idx])\n",
    "    return np.sum(np.square(data[:,y_idx]-mean))/N\n",
    "\n",
    "# MSE(training_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义一个在指定特征与特征值下将数据集二分的函数，这里将小于等于分割值的数据集放入左分支，大于分割值的数据集放入右分支。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def BinSplitData(data,f_idx,f_val):\n",
    "    '''\n",
    "    以指定特征与特征值二分数据集\n",
    "    '''\n",
    "    data_left=data[data[:,f_idx]<=f_val]\n",
    "    data_right=data[data[:,f_idx]>f_val]\n",
    "    return data_left,data_right\n",
    "\n",
    "# SplitData(training_data,0,0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "分割函数与分割指标计算函数都有了，接下来就可以在数据集中迭代寻找最佳分割特征与特征值了。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Test(data, criteria='mse', min_samples_split=5, min_samples_leaf=5, min_impurity_decrease=0.0):\n",
    "    '''\n",
    "    对数据做test，找到最佳分割特征与特征值\n",
    "    return: best_f_idx, best_f_val，前者为空时代表叶节点，两者都为空时说明无法分裂\n",
    "    min_samples_split: 分裂所需的最小样本数，大于1\n",
    "    min_samples_leaf: 叶子节点的最小样本数，大于0\n",
    "    min_impurity_decrease: 分裂需要满足的最小增益\n",
    "    '''\n",
    "    n_sample, n_feature = data.shape\n",
    "\n",
    "    # 数据量小于阈值则直接返回叶节点，数据已纯净也返回叶节点\n",
    "    if n_sample < min_samples_split or len(np.unique(data[:,-1]))==1:    \n",
    "        return None, np.mean(data[:, -1])\n",
    "\n",
    "    MSE_before = MSE(data)    # 分裂前的MSE\n",
    "    best_gain = 0\n",
    "    best_f_idx = None\n",
    "    best_f_val = np.mean(data[:, -1])    # 默认分割值设为目标均值，当找不到分割点时返回该值作为叶节点\n",
    "\n",
    "    # 遍历所有特征与特征值\n",
    "    for f_idx in range(n_feature-1):\n",
    "        for f_val in np.unique(data[:, f_idx]):\n",
    "            data_left, data_right = BinSplitData(data, f_idx, f_val)    # 二分数据\n",
    "\n",
    "            # 分割后的分支样本数小于阈值则放弃分裂\n",
    "            if len(data_left) < min_samples_leaf or len(data_right) < min_samples_leaf:\n",
    "                continue\n",
    "\n",
    "            # 分割后的加权MSE\n",
    "            MSE_after = len(data_left)/n_sample*MSE(data_left) + \\\n",
    "                len(data_right)/n_sample*MSE(data_right)\n",
    "            gain = MSE_before-MSE_after    # MSE的减小量为增益\n",
    "\n",
    "            # 分裂后的增益小于阈值或小于目前最大增益则放弃分裂\n",
    "            if gain < min_impurity_decrease or gain < best_gain:\n",
    "                continue\n",
    "            else:\n",
    "                # 否则更新最大增益\n",
    "                best_gain = gain\n",
    "                best_f_idx, best_f_val = f_idx, f_val\n",
    "\n",
    "    # 返回一个最佳分割特征与最佳分割点，注意会有空的情况\n",
    "    return best_f_idx, best_f_val\n",
    "\n",
    "# Test(training_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最后就可以使用递归来生成树了。树中每一个节点需要保存的信息有：分割特征，分割点，以及左右分支。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def CART(data,criteria='mse',min_samples_split=5,min_samples_leaf=5,min_impurity_decrease=0.0):\n",
    "    # 首先是做test，数据集的质量由Test函数来保证并提供反馈\n",
    "    best_f_idx,best_f_val=Test(data,criteria,min_samples_split,min_samples_leaf,min_impurity_decrease)\n",
    "    \n",
    "    tree={}\n",
    "    tree['cut_f']=best_f_idx\n",
    "    tree['cut_val']=best_f_val\n",
    "    \n",
    "    if best_f_idx==None:    # f_idx为空表示需要生成叶节点\n",
    "        return best_f_val\n",
    "    \n",
    "    data_left,data_right=BinSplitData(data,best_f_idx,best_f_val)\n",
    "    tree['left']=CART(data_left,criteria,min_samples_split,min_samples_leaf,min_impurity_decrease)\n",
    "    tree['right']=CART(data_right,criteria,min_samples_split,min_samples_leaf,min_impurity_decrease)\n",
    "    \n",
    "    return tree\n",
    "\n",
    "tree=CART(training_data)\n",
    "# print(tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE:15.101454048749126\n"
     ]
    }
   ],
   "source": [
    "def predict_one(x_test, tree, default=-1):\n",
    "    if isinstance(tree, dict):    # 非叶节点才做左右判断\n",
    "        cut_f_idx, cut_val = tree['cut_f'], tree['cut_val']\n",
    "        sub_tree = tree['left'] if x_test[cut_f_idx] <= cut_val else tree['right']\n",
    "        return predict_one(x_test, sub_tree)\n",
    "    else:    # 叶节点则直接返回值\n",
    "        return tree\n",
    "    \n",
    "# test_idx=100\n",
    "# print(predict_one(X_test[test_idx],tree),Y_test[test_idx])\n",
    "    \n",
    "def predict(X_test,tree):\n",
    "    return [predict_one(x_test,tree) for x_test in X_test]\n",
    "    \n",
    "Y_pred=predict(X_test,tree)\n",
    "print('MSE:{}'.format(np.mean(np.square(Y_pred-Y_test))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用sklearn中的回归树来做效果对比。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MSE:13.951683170768305\n"
     ]
    }
   ],
   "source": [
    "from sklearn.tree import DecisionTreeRegressor\n",
    "dt_reg=DecisionTreeRegressor(min_samples_split=5, min_samples_leaf=5)\n",
    "dt_reg.fit(X_train,Y_train)\n",
    "Y_pred=dt_reg.predict(X_test)\n",
    "print('MSE:{}'.format(np.mean(np.square(Y_pred-Y_test))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
